{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using Large Datasets\n",
    "> Tutorial on how to train neuralforecast models on datasets that cannot fit into memory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The standard DataLoader class used by NeuralForecast expects the dataset to be represented by a single DataFrame, which is entirely loaded into memory when fitting the model. However, when the dataset is too large for this, we can instead use the custom large-scale DataLoader. This custom loader assumes that each timeseries is split across a collection of Parquet files, and ensure that only one batch is ever loaded into memory at a given time.\n",
    "\n",
    "In this notebook, we will demonstrate the expected format of these files, how to train the model and and how to perform inference using this large-scale DataLoader."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import tempfile\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "from neuralforecast import NeuralForecast\n",
    "from neuralforecast.models import NHITS\n",
    "from utilsforecast.evaluation import evaluate\n",
    "from utilsforecast.losses import mae, rmse, smape\n",
    "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.getLogger('pytorch_lightning').setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data\n",
    "\n",
    "Each timeseries should be stored in a directory named **unique_id=timeseries_id**. Within this directory, the timeseries can be entirely contained in a single Parquet file or split across multiple Parquet files. Regardless of the format, the timeseries must be ordered by time.\n",
    "\n",
    "For example, the following code splits the AirPassengers DataFrame (of which each timeseries is already sorted by time) into the below format:\n",
    "<br>\n",
    "<br>\n",
    "\n",
    "**\\>**&nbsp;&nbsp;data  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;**\\>**&nbsp;&nbsp;unique_id=Airline1  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; -&nbsp;&nbsp;a59945617fdb40d1bc6caa4aadad881c-0.parquet  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;**\\>**&nbsp;&nbsp;unique_id=Airline2  \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp; -&nbsp;&nbsp;a59945617fdb40d1bc6caa4aadad881c-0.parquet  \n",
    "\n",
    "<br>\n",
    "We then simply input a list of the paths to these directories."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>unique_id</th>\n",
       "      <th>ds</th>\n",
       "      <th>y</th>\n",
       "      <th>trend</th>\n",
       "      <th>y_[lag12]</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Airline1</td>\n",
       "      <td>1949-01-31</td>\n",
       "      <td>112.0</td>\n",
       "      <td>0</td>\n",
       "      <td>112.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Airline1</td>\n",
       "      <td>1949-02-28</td>\n",
       "      <td>118.0</td>\n",
       "      <td>1</td>\n",
       "      <td>118.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Airline1</td>\n",
       "      <td>1949-03-31</td>\n",
       "      <td>132.0</td>\n",
       "      <td>2</td>\n",
       "      <td>132.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Airline1</td>\n",
       "      <td>1949-04-30</td>\n",
       "      <td>129.0</td>\n",
       "      <td>3</td>\n",
       "      <td>129.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Airline1</td>\n",
       "      <td>1949-05-31</td>\n",
       "      <td>121.0</td>\n",
       "      <td>4</td>\n",
       "      <td>121.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>283</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-08-31</td>\n",
       "      <td>906.0</td>\n",
       "      <td>283</td>\n",
       "      <td>859.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>284</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-09-30</td>\n",
       "      <td>808.0</td>\n",
       "      <td>284</td>\n",
       "      <td>763.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>285</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-10-31</td>\n",
       "      <td>761.0</td>\n",
       "      <td>285</td>\n",
       "      <td>707.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>286</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-11-30</td>\n",
       "      <td>690.0</td>\n",
       "      <td>286</td>\n",
       "      <td>662.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>287</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-12-31</td>\n",
       "      <td>732.0</td>\n",
       "      <td>287</td>\n",
       "      <td>705.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>288 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "    unique_id         ds      y  trend  y_[lag12]\n",
       "0    Airline1 1949-01-31  112.0      0      112.0\n",
       "1    Airline1 1949-02-28  118.0      1      118.0\n",
       "2    Airline1 1949-03-31  132.0      2      132.0\n",
       "3    Airline1 1949-04-30  129.0      3      129.0\n",
       "4    Airline1 1949-05-31  121.0      4      121.0\n",
       "..        ...        ...    ...    ...        ...\n",
       "283  Airline2 1960-08-31  906.0    283      859.0\n",
       "284  Airline2 1960-09-30  808.0    284      763.0\n",
       "285  Airline2 1960-10-31  761.0    285      707.0\n",
       "286  Airline2 1960-11-30  690.0    286      662.0\n",
       "287  Airline2 1960-12-31  732.0    287      705.0\n",
       "\n",
       "[288 rows x 5 columns]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y_df = AirPassengersPanel.copy()\n",
    "Y_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['C:\\\\Users\\\\ospra\\\\AppData\\\\Local\\\\Temp\\\\tmpxe__gjoo/unique_id=Airline1',\n",
       " 'C:\\\\Users\\\\ospra\\\\AppData\\\\Local\\\\Temp\\\\tmpxe__gjoo/unique_id=Airline2']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid = Y_df.groupby('unique_id').tail(72)\n",
    "# from now on we will use the id_col as the unique identifier for the timeseries (this is because we are using the unique_id column to partition the data into parquet files)\n",
    "valid = valid.rename(columns={'unique_id': 'id_col'})\n",
    "\n",
    "train = Y_df.drop(valid.index)\n",
    "train['id_col'] = train['unique_id'].copy()\n",
    "\n",
    "# we generate the files using a temporary directory here to demonstrate the expected file structure\n",
    "tmpdir = tempfile.TemporaryDirectory()\n",
    "train.to_parquet(tmpdir.name, partition_cols=['unique_id'], index=False)\n",
    "files_list = [f\"{tmpdir.name}/{dir}\" for dir in os.listdir(tmpdir.name)]\n",
    "files_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also create this directory structure with a spark dataframe using the following:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "spark.conf.set(\"spark.sql.parquet.outputTimestampType\", \"TIMESTAMP_MICROS\")\n",
    "(\n",
    "  spark_df\n",
    "  .repartition(id_col)\n",
    "  .sortWithinPartitions(id_col, time_col)\n",
    "  .write\n",
    "  .partitionBy(id_col)\n",
    "  .parquet(out_dir)\n",
    ")\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The DataLoader class still expects the static data to be passed in as a single DataFrame with one row per timeseries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id_col</th>\n",
       "      <th>airline1</th>\n",
       "      <th>airline2</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Airline1</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     id_col  airline1  airline2\n",
       "0  Airline1         0         1\n",
       "1  Airline2         1         0"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "static = AirPassengersStatic.rename(columns={'unique_id': 'id_col'})\n",
    "static"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model training\n",
    "\n",
    "We now train a NHITS model on the above dataset. \n",
    "It is worth noting that NeuralForecast currently does not support scaling when using this DataLoader. If you want to scale the timeseries this should be done before passing it in to the `fit` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da7f05dbbb4444baaa3eb6d6a2b13c4a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e28728371ccb4a21b266896fc5968956",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dadb7357a3fc4e0e8086d7cac870bfc3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "horizon = 12\n",
    "stacks = 3\n",
    "models = [NHITS(input_size=5 * horizon,\n",
    "                h=horizon,\n",
    "                futr_exog_list=['trend', 'y_[lag12]'],\n",
    "                stat_exog_list=['airline1', 'airline2'],\n",
    "                max_steps=100,\n",
    "                stack_types = stacks*['identity'],\n",
    "                n_blocks = stacks*[1],\n",
    "                mlp_units = [[256,256] for _ in range(stacks)],\n",
    "                n_pool_kernel_size = stacks*[1],\n",
    "                interpolation_mode=\"nearest\")]\n",
    "nf = NeuralForecast(models=models, freq='ME')\n",
    "nf.fit(df=files_list, static_df=static, id_col='id_col')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Forecasting\n",
    "\n",
    "When working with large datasets, we need to provide a single DataFrame containing the input timesteps of all the timeseries for which wish to generate predictions. If we have future exogenous features, we should also include the future values of these features in the separate `futr_df` DataFrame. \n",
    "\n",
    "For the below prediction we are assuming we only want to predict the next 12 timesteps for Airline2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f7e42194b6a74bdc8363e16a3c5f13a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Predicting: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "valid_df = valid[valid['id_col'] == 'Airline2']\n",
    "# we set input_size=60 and horizon=12 when fitting the model\n",
    "pred_df = valid_df[:60]\n",
    "futr_df = valid_df[60:72]\n",
    "futr_df = futr_df.drop([\"y\"], axis=1)\n",
    "\n",
    "predictions = nf.predict(df=pred_df, futr_df=futr_df, static_df=static)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id_col</th>\n",
       "      <th>ds</th>\n",
       "      <th>NHITS</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-01-31</td>\n",
       "      <td>713.441406</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-02-29</td>\n",
       "      <td>688.176880</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-03-31</td>\n",
       "      <td>763.382935</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-04-30</td>\n",
       "      <td>745.478027</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-05-31</td>\n",
       "      <td>758.036438</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-06-30</td>\n",
       "      <td>806.288574</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-07-31</td>\n",
       "      <td>869.563782</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-08-31</td>\n",
       "      <td>858.105896</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-09-30</td>\n",
       "      <td>803.531555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-10-31</td>\n",
       "      <td>751.093079</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-11-30</td>\n",
       "      <td>700.435852</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>Airline2</td>\n",
       "      <td>1960-12-31</td>\n",
       "      <td>746.640259</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      id_col         ds       NHITS\n",
       "0   Airline2 1960-01-31  713.441406\n",
       "1   Airline2 1960-02-29  688.176880\n",
       "2   Airline2 1960-03-31  763.382935\n",
       "3   Airline2 1960-04-30  745.478027\n",
       "4   Airline2 1960-05-31  758.036438\n",
       "5   Airline2 1960-06-30  806.288574\n",
       "6   Airline2 1960-07-31  869.563782\n",
       "7   Airline2 1960-08-31  858.105896\n",
       "8   Airline2 1960-09-30  803.531555\n",
       "9   Airline2 1960-10-31  751.093079\n",
       "10  Airline2 1960-11-30  700.435852\n",
       "11  Airline2 1960-12-31  746.640259"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target = valid_df[60:72]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>metric</th>\n",
       "      <th>NHITS</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>mae</td>\n",
       "      <td>20.728617</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>rmse</td>\n",
       "      <td>26.980698</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>smape</td>\n",
       "      <td>0.012879</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  metric      NHITS\n",
       "0    mae  20.728617\n",
       "1   rmse  26.980698\n",
       "2  smape   0.012879"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "evaluate(\n",
    "    predictions.merge(target.drop([\"trend\", \"y_[lag12]\"], axis=1), on=['id_col', 'ds']),\n",
    "    metrics=[mae, rmse, smape],\n",
    "    id_col='id_col',\n",
    "    agg_fn='mean',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
